#include <stdio.h>
#include <string.h>
#include <time.h>
#include <cctype>
#include <filesystem>
#include <functional>
#include <iostream>
#include <map>
#include <random>
#include <regex>
#include <sstream>
#include <string>
#include <vector>

// #include "preprocessing.hpp"
#include "stable-diffusion.h"

#define STB_IMAGE_IMPLEMENTATION
#define STB_IMAGE_STATIC
#include "stb_image.h"

#define STB_IMAGE_WRITE_IMPLEMENTATION
#define STB_IMAGE_WRITE_STATIC
#include "stb_image_write.h"

#define STB_IMAGE_RESIZE_IMPLEMENTATION
#define STB_IMAGE_RESIZE_STATIC
#include "stb_image_resize.h"

#include "avi_writer.h"

#if defined(_WIN32)
#define NOMINMAX
#include <windows.h>
#endif  // _WIN32

#define SAFE_STR(s) ((s) ? (s) : "")
#define BOOL_STR(b) ((b) ? "true" : "false")

namespace fs = std::filesystem;

const char* modes_str[] = {
    "img_gen",
    "vid_gen",
    "convert",
    "upscale",
};
#define SD_ALL_MODES_STR "img_gen, vid_gen, convert, upscale"

const char* previews_str[] = {
    "none",
    "proj",
    "tae",
    "vae",
};

enum SDMode {
    IMG_GEN,
    VID_GEN,
    CONVERT,
    UPSCALE,
    MODE_COUNT
};

#if defined(_WIN32)
static std::string utf16_to_utf8(const std::wstring& wstr) {
    if (wstr.empty())
        return {};
    int size_needed = WideCharToMultiByte(CP_UTF8, 0, wstr.data(), (int)wstr.size(),
                                          nullptr, 0, nullptr, nullptr);
    if (size_needed <= 0)
        throw std::runtime_error("UTF-16 to UTF-8 conversion failed");

    std::string utf8(size_needed, 0);
    WideCharToMultiByte(CP_UTF8, 0, wstr.data(), (int)wstr.size(),
                        (char*)utf8.data(), size_needed, nullptr, nullptr);
    return utf8;
}

static std::string argv_to_utf8(int index, const char** argv) {
    int argc;
    wchar_t** argv_w = CommandLineToArgvW(GetCommandLineW(), &argc);
    if (!argv_w)
        throw std::runtime_error("Failed to parse command line");

    std::string result;
    if (index < argc) {
        result = utf16_to_utf8(argv_w[index]);
    }
    LocalFree(argv_w);
    return result;
}

#else  // Linux / macOS
static std::string argv_to_utf8(int index, const char** argv) {
    return std::string(argv[index]);
}

#endif

struct StringOption {
    std::string short_name;
    std::string long_name;
    std::string desc;
    std::string* target;
};

struct IntOption {
    std::string short_name;
    std::string long_name;
    std::string desc;
    int* target;
};

struct FloatOption {
    std::string short_name;
    std::string long_name;
    std::string desc;
    float* target;
};

struct BoolOption {
    std::string short_name;
    std::string long_name;
    std::string desc;
    bool keep_true;
    bool* target;
};

struct ManualOption {
    std::string short_name;
    std::string long_name;
    std::string desc;
    std::function<int(int argc, const char** argv, int index)> cb;
};

struct ArgOptions {
    std::vector<StringOption> string_options;
    std::vector<IntOption> int_options;
    std::vector<FloatOption> float_options;
    std::vector<BoolOption> bool_options;
    std::vector<ManualOption> manual_options;

    static std::string wrap_text(const std::string& text, size_t width, size_t indent) {
        std::ostringstream oss;
        size_t line_len = 0;
        size_t pos      = 0;

        while (pos < text.size()) {
            // Preserve manual newlines
            if (text[pos] == '\n') {
                oss << '\n'
                    << std::string(indent, ' ');
                line_len = indent;
                ++pos;
                continue;
            }

            // Add the character
            oss << text[pos];
            ++line_len;
            ++pos;

            // If the current line exceeds width, try to break at the last space
            if (line_len >= width) {
                std::string current = oss.str();
                size_t back         = current.size();

                // Find the last space (for a clean break)
                while (back > 0 && current[back - 1] != ' ' && current[back - 1] != '\n')
                    --back;

                // If found a space to break on
                if (back > 0 && current[back - 1] != '\n') {
                    std::string before = current.substr(0, back - 1);
                    std::string after  = current.substr(back);
                    oss.str("");
                    oss.clear();
                    oss << before << "\n"
                        << std::string(indent, ' ') << after;
                } else {
                    // If no space found, just break at width
                    oss << "\n"
                        << std::string(indent, ' ');
                }
                line_len = indent;
            }
        }

        return oss.str();
    }

    void print() const {
        constexpr size_t max_line_width = 120;

        struct Entry {
            std::string names;
            std::string desc;
        };
        std::vector<Entry> entries;

        auto add_entry = [&](const std::string& s, const std::string& l,
                             const std::string& desc, const std::string& hint = "") {
            std::ostringstream ss;
            if (!s.empty())
                ss << s;
            if (!s.empty() && !l.empty())
                ss << ", ";
            if (!l.empty())
                ss << l;
            if (!hint.empty())
                ss << " " << hint;
            entries.push_back({ss.str(), desc});
        };

        for (auto& o : string_options)
            add_entry(o.short_name, o.long_name, o.desc, "<string>");
        for (auto& o : int_options)
            add_entry(o.short_name, o.long_name, o.desc, "<int>");
        for (auto& o : float_options)
            add_entry(o.short_name, o.long_name, o.desc, "<float>");
        for (auto& o : bool_options)
            add_entry(o.short_name, o.long_name, o.desc, "");
        for (auto& o : manual_options)
            add_entry(o.short_name, o.long_name, o.desc);

        size_t max_name_width = 0;
        for (auto& e : entries)
            max_name_width = std::max(max_name_width, e.names.size());

        for (auto& e : entries) {
            size_t indent            = 2 + max_name_width + 4;
            size_t desc_width        = (max_line_width > indent ? max_line_width - indent : 40);
            std::string wrapped_desc = wrap_text(e.desc, max_line_width, indent);
            std::cout << "  " << std::left << std::setw(static_cast<int>(max_name_width) + 4)
                      << e.names << wrapped_desc << "\n";
        }
    }
};

bool parse_options(int argc, const char** argv, const std::vector<ArgOptions>& options_list) {
    bool invalid_arg = false;
    std::string arg;

    auto match_and_apply = [&](auto& opts, auto&& apply_fn) -> bool {
        for (auto& option : opts) {
            if ((option.short_name.size() > 0 && arg == option.short_name) ||
                (option.long_name.size() > 0 && arg == option.long_name)) {
                apply_fn(option);
                return true;
            }
        }
        return false;
    };

    for (int i = 1; i < argc; i++) {
        arg            = argv[i];
        bool found_arg = false;

        for (auto& options : options_list) {
            if (match_and_apply(options.string_options, [&](auto& option) {
                    if (++i >= argc) {
                        invalid_arg = true;
                        return;
                    }
                    *option.target = argv_to_utf8(i, argv);
                    found_arg      = true;
                }))
                break;

            if (match_and_apply(options.int_options, [&](auto& option) {
                    if (++i >= argc) {
                        invalid_arg = true;
                        return;
                    }
                    *option.target = std::stoi(argv[i]);
                    found_arg      = true;
                }))
                break;

            if (match_and_apply(options.float_options, [&](auto& option) {
                    if (++i >= argc) {
                        invalid_arg = true;
                        return;
                    }
                    *option.target = std::stof(argv[i]);
                    found_arg      = true;
                }))
                break;

            if (match_and_apply(options.bool_options, [&](auto& option) {
                    *option.target = option.keep_true ? true : false;
                    found_arg      = true;
                }))
                break;

            if (match_and_apply(options.manual_options, [&](auto& option) {
                    int ret = option.cb(argc, argv, i);
                    if (ret < 0) {
                        invalid_arg = true;
                        return;
                    }
                    i += ret;
                    found_arg = true;
                }))
                break;
        }

        if (invalid_arg) {
            fprintf(stderr, "error: invalid parameter for argument: %s\n", arg.c_str());
            return false;
        }
        if (!found_arg) {
            fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
            return false;
        }
    }

    return true;
}

struct SDCliParams {
    SDMode mode             = IMG_GEN;
    std::string output_path = "output.png";

    bool verbose          = false;
    bool canny_preprocess = false;

    preview_t preview_method = PREVIEW_NONE;
    int preview_interval     = 1;
    std::string preview_path = "preview.png";
    int preview_fps          = 16;
    bool taesd_preview       = false;
    bool preview_noisy       = false;
    bool color               = false;

    bool normal_exit = false;

    ArgOptions get_options() {
        ArgOptions options;

        options.string_options = {
            {"-o",
             "--output",
             "path to write result image to (default: ./output.png)",
             &output_path},
            {"",
             "--preview-path",
             "path to write preview image to (default: ./preview.png)",
             &preview_path},
        };

        options.int_options = {
            {"",
             "--preview-interval",
             "interval in denoising steps between consecutive updates of the image preview file (default is 1, meaning updating at every step)",
             &preview_interval},
        };

        options.bool_options = {
            {"",
             "--canny",
             "apply canny preprocessor (edge detection)",
             true, &canny_preprocess},
            {"-v",
             "--verbose",
             "print extra info",
             true, &verbose},
            {"",
             "--color",
             "colors the logging tags according to level",
             true, &color},
            {"",
             "--taesd-preview-only",
             std::string("prevents usage of taesd for decoding the final image. (for use with --preview ") + previews_str[PREVIEW_TAE] + ")",
             true, &taesd_preview},
            {"",
             "--preview-noisy",
             "enables previewing noisy inputs of the models rather than the denoised outputs",
             true, &preview_noisy},

        };

        auto on_mode_arg = [&](int argc, const char** argv, int index) {
            if (++index >= argc) {
                return -1;
            }
            const char* mode_c_str = argv[index];
            if (mode_c_str != nullptr) {
                int mode_found = -1;
                for (int i = 0; i < MODE_COUNT; i++) {
                    if (!strcmp(mode_c_str, modes_str[i])) {
                        mode_found = i;
                    }
                }
                if (mode_found == -1) {
                    fprintf(stderr,
                            "error: invalid mode %s, must be one of [%s]\n",
                            mode_c_str, SD_ALL_MODES_STR);
                    exit(1);
                }
                mode = (SDMode)mode_found;
            }
            return 1;
        };

        auto on_preview_arg = [&](int argc, const char** argv, int index) {
            if (++index >= argc) {
                return -1;
            }
            const char* preview = argv[index];
            int preview_found   = -1;
            for (int m = 0; m < PREVIEW_COUNT; m++) {
                if (!strcmp(preview, previews_str[m])) {
                    preview_found = m;
                }
            }
            if (preview_found == -1) {
                fprintf(stderr, "error: preview method %s\n",
                        preview);
                return -1;
            }
            preview_method = (preview_t)preview_found;
            return 1;
        };

        auto on_help_arg = [&](int argc, const char** argv, int index) {
            normal_exit = true;
            return -1;
        };

        options.manual_options = {
            {"-M",
             "--mode",
             "run mode, one of [img_gen, vid_gen, upscale, convert], default: img_gen",
             on_mode_arg},
            {"",
             "--preview",
             std::string("preview method. must be one of the following [") + previews_str[0] + ", " + previews_str[1] + ", " + previews_str[2] + ", " + previews_str[3] + "] (default is " + previews_str[PREVIEW_NONE] + ")",
             on_preview_arg},
            {"-h",
             "--help",
             "show this help message and exit",
             on_help_arg},
        };

        return options;
    };

    bool process_and_check() {
        if (output_path.length() == 0) {
            fprintf(stderr, "error: the following arguments are required: output_path\n");
            return false;
        }

        if (mode == CONVERT) {
            if (output_path == "output.png") {
                output_path = "output.gguf";
            }
        }
        return true;
    }

    std::string to_string() const {
        std::ostringstream oss;
        oss << "SDCliParams {\n"
            << "  mode: " << modes_str[mode] << ",\n"
            << "  output_path: \"" << output_path << "\",\n"
            << "  verbose: " << (verbose ? "true" : "false") << ",\n"
            << "  color: " << (color ? "true" : "false") << ",\n"
            << "  canny_preprocess: " << (canny_preprocess ? "true" : "false") << ",\n"
            << "  preview_method: " << previews_str[preview_method] << ",\n"
            << "  preview_interval: " << preview_interval << ",\n"
            << "  preview_path: \"" << preview_path << "\",\n"
            << "  preview_fps: " << preview_fps << ",\n"
            << "  taesd_preview: " << (taesd_preview ? "true" : "false") << ",\n"
            << "  preview_noisy: " << (preview_noisy ? "true" : "false") << "\n"
            << "}";
        return oss.str();
    }
};

struct SDContextParams {
    int n_threads = -1;
    std::string model_path;
    std::string clip_l_path;
    std::string clip_g_path;
    std::string clip_vision_path;
    std::string t5xxl_path;
    std::string llm_path;
    std::string llm_vision_path;
    std::string diffusion_model_path;
    std::string high_noise_diffusion_model_path;
    std::string vae_path;
    std::string taesd_path;
    std::string esrgan_path;
    std::string control_net_path;
    std::string embedding_dir;
    std::string photo_maker_path;
    sd_type_t wtype = SD_TYPE_COUNT;
    std::string tensor_type_rules;
    std::string lora_model_dir;

    std::map<std::string, std::string> embedding_map;
    std::vector<sd_embedding_t> embedding_array;

    rng_type_t rng_type         = CUDA_RNG;
    rng_type_t sampler_rng_type = RNG_TYPE_COUNT;
    bool offload_params_to_cpu  = false;
    bool control_net_cpu        = false;
    bool clip_on_cpu            = false;
    bool vae_on_cpu             = false;
    bool diffusion_flash_attn   = false;
    bool diffusion_conv_direct  = false;
    bool vae_conv_direct        = false;

    bool chroma_use_dit_mask = true;
    bool chroma_use_t5_mask  = false;
    int chroma_t5_mask_pad   = 1;

    prediction_t prediction           = PREDICTION_COUNT;
    lora_apply_mode_t lora_apply_mode = LORA_APPLY_AUTO;

    sd_tiling_params_t vae_tiling_params = {false, 0, 0, 0.5f, 0.0f, 0.0f};
    bool force_sdxl_vae_conv_scale       = false;

    float flow_shift = INFINITY;

    ArgOptions get_options() {
        ArgOptions options;
        options.string_options = {
            {"-m",
             "--model",
             "path to full model",
             &model_path},
            {"",
             "--clip_l",
             "path to the clip-l text encoder", &clip_l_path},
            {"", "--clip_g",
             "path to the clip-g text encoder",
             &clip_g_path},
            {"",
             "--clip_vision",
             "path to the clip-vision encoder",
             &clip_vision_path},
            {"",
             "--t5xxl",
             "path to the t5xxl text encoder",
             &t5xxl_path},
            {"",
             "--llm",
             "path to the llm text encoder. For example: (qwenvl2.5 for qwen-image, mistral-small3.2 for flux2, ...)",
             &llm_path},
            {"",
             "--llm_vision",
             "path to the llm vit",
             &llm_vision_path},
            {"",
             "--qwen2vl",
             "alias of --llm. Deprecated.",
             &llm_path},
            {"",
             "--qwen2vl_vision",
             "alias of --llm_vision. Deprecated.",
             &llm_vision_path},
            {"",
             "--diffusion-model",
             "path to the standalone diffusion model",
             &diffusion_model_path},
            {"",
             "--high-noise-diffusion-model",
             "path to the standalone high noise diffusion model",
             &high_noise_diffusion_model_path},
            {"",
             "--vae",
             "path to standalone vae model",
             &vae_path},
            {"",
             "--taesd",
             "path to taesd. Using Tiny AutoEncoder for fast decoding (low quality)",
             &taesd_path},
            {"",
             "--control-net",
             "path to control net model",
             &control_net_path},
            {"",
             "--embd-dir",
             "embeddings directory",
             &embedding_dir},
            {"",
             "--lora-model-dir",
             "lora model directory",
             &lora_model_dir},

            {"",
             "--tensor-type-rules",
             "weight type per tensor pattern (example: \"^vae\\.=f16,model\\.=q8_0\")",
             &tensor_type_rules},
            {"",
             "--photo-maker",
             "path to PHOTOMAKER model",
             &photo_maker_path},
            {"",
             "--upscale-model",
             "path to esrgan model.",
             &esrgan_path},
        };

        options.int_options = {
            {"-t",
             "--threads",
             "number of threads to use during computation (default: -1). "
             "If threads <= 0, then threads will be set to the number of CPU physical cores",
             &n_threads},
            {"",
             "--chroma-t5-mask-pad",
             "t5 mask pad size of chroma",
             &chroma_t5_mask_pad},
        };

        options.float_options = {
            {"",
             "--vae-tile-overlap",
             "tile overlap for vae tiling, in fraction of tile size (default: 0.5)",
             &vae_tiling_params.target_overlap},
            {"",
             "--flow-shift",
             "shift value for Flow models like SD3.x or WAN (default: auto)",
             &flow_shift},
        };

        options.bool_options = {
            {"",
             "--vae-tiling",
             "process vae in tiles to reduce memory usage",
             true, &vae_tiling_params.enabled},
            {"",
             "--force-sdxl-vae-conv-scale",
             "force use of conv scale on sdxl vae",
             true, &force_sdxl_vae_conv_scale},
            {"",
             "--offload-to-cpu",
             "place the weights in RAM to save VRAM, and automatically load them into VRAM when needed",
             true, &offload_params_to_cpu},
            {"",
             "--control-net-cpu",
             "keep controlnet in cpu (for low vram)",
             true, &control_net_cpu},
            {"",
             "--clip-on-cpu",
             "keep clip in cpu (for low vram)",
             true, &clip_on_cpu},
            {"",
             "--vae-on-cpu",
             "keep vae in cpu (for low vram)",
             true, &vae_on_cpu},
            {"",
             "--diffusion-fa",
             "use flash attention in the diffusion model",
             true, &diffusion_flash_attn},
            {"",
             "--diffusion-conv-direct",
             "use ggml_conv2d_direct in the diffusion model",
             true, &diffusion_conv_direct},
            {"",
             "--vae-conv-direct",
             "use ggml_conv2d_direct in the vae model",
             true, &vae_conv_direct},
            {"",
             "--chroma-disable-dit-mask",
             "disable dit mask for chroma",
             false, &chroma_use_dit_mask},
            {"",
             "--chroma-enable-t5-mask",
             "enable t5 mask for chroma",
             true, &chroma_use_t5_mask},
        };

        auto on_type_arg = [&](int argc, const char** argv, int index) {
            if (++index >= argc) {
                return -1;
            }
            const char* arg = argv[index];
            wtype           = str_to_sd_type(arg);
            if (wtype == SD_TYPE_COUNT) {
                fprintf(stderr, "error: invalid weight format %s\n",
                        arg);
                return -1;
            }
            return 1;
        };

        auto on_rng_arg = [&](int argc, const char** argv, int index) {
            if (++index >= argc) {
                return -1;
            }
            const char* arg = argv[index];
            rng_type        = str_to_rng_type(arg);
            if (rng_type == RNG_TYPE_COUNT) {
                fprintf(stderr, "error: invalid rng type %s\n",
                        arg);
                return -1;
            }
            return 1;
        };

        auto on_sampler_rng_arg = [&](int argc, const char** argv, int index) {
            if (++index >= argc) {
                return -1;
            }
            const char* arg  = argv[index];
            sampler_rng_type = str_to_rng_type(arg);
            if (sampler_rng_type == RNG_TYPE_COUNT) {
                fprintf(stderr, "error: invalid sampler rng type %s\n",
                        arg);
                return -1;
            }
            return 1;
        };

        auto on_prediction_arg = [&](int argc, const char** argv, int index) {
            if (++index >= argc) {
                return -1;
            }
            const char* arg = argv[index];
            prediction      = str_to_prediction(arg);
            if (prediction == PREDICTION_COUNT) {
                fprintf(stderr, "error: invalid prediction type %s\n",
                        arg);
                return -1;
            }
            return 1;
        };

        auto on_lora_apply_mode_arg = [&](int argc, const char** argv, int index) {
            if (++index >= argc) {
                return -1;
            }
            const char* arg = argv[index];
            lora_apply_mode = str_to_lora_apply_mode(arg);
            if (lora_apply_mode == LORA_APPLY_MODE_COUNT) {
                fprintf(stderr, "error: invalid lora apply model %s\n",
                        arg);
                return -1;
            }
            return 1;
        };

        auto on_tile_size_arg = [&](int argc, const char** argv, int index) {
            if (++index >= argc) {
                return -1;
            }
            std::string tile_size_str = argv[index];
            size_t x_pos              = tile_size_str.find('x');
            try {
                if (x_pos != std::string::npos) {
                    std::string tile_x_str        = tile_size_str.substr(0, x_pos);
                    std::string tile_y_str        = tile_size_str.substr(x_pos + 1);
                    vae_tiling_params.tile_size_x = std::stoi(tile_x_str);
                    vae_tiling_params.tile_size_y = std::stoi(tile_y_str);
                } else {
                    vae_tiling_params.tile_size_x = vae_tiling_params.tile_size_y = std::stoi(tile_size_str);
                }
            } catch (const std::invalid_argument&) {
                return -1;
            } catch (const std::out_of_range&) {
                return -1;
            }
            return 1;
        };

        auto on_relative_tile_size_arg = [&](int argc, const char** argv, int index) {
            if (++index >= argc) {
                return -1;
            }
            std::string rel_size_str = argv[index];
            size_t x_pos             = rel_size_str.find('x');
            try {
                if (x_pos != std::string::npos) {
                    std::string rel_x_str        = rel_size_str.substr(0, x_pos);
                    std::string rel_y_str        = rel_size_str.substr(x_pos + 1);
                    vae_tiling_params.rel_size_x = std::stof(rel_x_str);
                    vae_tiling_params.rel_size_y = std::stof(rel_y_str);
                } else {
                    vae_tiling_params.rel_size_x = vae_tiling_params.rel_size_y = std::stof(rel_size_str);
                }
            } catch (const std::invalid_argument&) {
                return -1;
            } catch (const std::out_of_range&) {
                return -1;
            }
            return 1;
        };

        options.manual_options = {
            {"",
             "--type",
             "weight type (examples: f32, f16, q4_0, q4_1, q5_0, q5_1, q8_0, q2_K, q3_K, q4_K). "
             "If not specified, the default is the type of the weight file",
             on_type_arg},
            {"",
             "--rng",
             "RNG, one of [std_default, cuda, cpu], default: cuda(sd-webui), cpu(comfyui)",
             on_rng_arg},
            {"",
             "--sampler-rng",
             "sampler RNG, one of [std_default, cuda, cpu]. If not specified, use --rng",
             on_sampler_rng_arg},
            {"",
             "--prediction",
             "prediction type override, one of [eps, v, edm_v, sd3_flow, flux_flow, flux2_flow]",
             on_prediction_arg},
            {"",
             "--lora-apply-mode",
             "the way to apply LoRA, one of [auto, immediately, at_runtime], default is auto. "
             "In auto mode, if the model weights contain any quantized parameters, the at_runtime mode will be used; otherwise, immediately will be used."
             "The immediately mode may have precision and compatibility issues with quantized parameters, "
             "but it usually offers faster inference speed and, in some cases, lower memory usage. "
             "The at_runtime mode, on the other hand, is exactly the opposite.",
             on_lora_apply_mode_arg},
            {"",
             "--vae-tile-size",
             "tile size for vae tiling, format [X]x[Y] (default: 32x32)",
             on_tile_size_arg},
            {"",
             "--vae-relative-tile-size",
             "relative tile size for vae tiling, format [X]x[Y], in fraction of image size if < 1, in number of tiles per dim if >=1 (overrides --vae-tile-size)",
             on_relative_tile_size_arg},
        };

        return options;
    }

    void build_embedding_map() {
        static const std::vector<std::string> valid_ext = {".pt", ".safetensors", ".gguf"};

        if (!fs::exists(embedding_dir) || !fs::is_directory(embedding_dir)) {
            return;
        }

        for (auto& p : fs::directory_iterator(embedding_dir)) {
            if (!p.is_regular_file())
                continue;

            auto path       = p.path();
            std::string ext = path.extension().string();

            bool valid = false;
            for (auto& e : valid_ext) {
                if (ext == e) {
                    valid = true;
                    break;
                }
            }
            if (!valid)
                continue;

            std::string key   = path.stem().string();
            std::string value = path.string();

            embedding_map[key] = value;
        }
    }

    bool process_and_check(SDMode mode) {
        if (mode != UPSCALE && model_path.length() == 0 && diffusion_model_path.length() == 0) {
            fprintf(stderr, "error: the following arguments are required: model_path/diffusion_model\n");
            return false;
        }

        if (mode == UPSCALE) {
            if (esrgan_path.length() == 0) {
                fprintf(stderr, "error: upscale mode needs an upscaler model (--upscale-model)\n");
                return false;
            }
        }

        if (n_threads <= 0) {
            n_threads = sd_get_num_physical_cores();
        }

        build_embedding_map();

        return true;
    }

    std::string to_string() const {
        std::ostringstream emb_ss;
        emb_ss << "{\n";
        for (auto it = embedding_map.begin(); it != embedding_map.end(); ++it) {
            emb_ss << "    \"" << it->first << "\": \"" << it->second << "\"";
            if (std::next(it) != embedding_map.end()) {
                emb_ss << ",";
            }
            emb_ss << "\n";
        }
        emb_ss << "  }";

        std::string embeddings_str = emb_ss.str();
        std::ostringstream oss;
        oss << "SDContextParams {\n"
            << "  n_threads: " << n_threads << ",\n"
            << "  model_path: \"" << model_path << "\",\n"
            << "  clip_l_path: \"" << clip_l_path << "\",\n"
            << "  clip_g_path: \"" << clip_g_path << "\",\n"
            << "  clip_vision_path: \"" << clip_vision_path << "\",\n"
            << "  t5xxl_path: \"" << t5xxl_path << "\",\n"
            << "  llm_path: \"" << llm_path << "\",\n"
            << "  llm_vision_path: \"" << llm_vision_path << "\",\n"
            << "  diffusion_model_path: \"" << diffusion_model_path << "\",\n"
            << "  high_noise_diffusion_model_path: \"" << high_noise_diffusion_model_path << "\",\n"
            << "  vae_path: \"" << vae_path << "\",\n"
            << "  taesd_path: \"" << taesd_path << "\",\n"
            << "  esrgan_path: \"" << esrgan_path << "\",\n"
            << "  control_net_path: \"" << control_net_path << "\",\n"
            << "  embedding_dir: \"" << embedding_dir << "\",\n"
            << "  embeddings: " << embeddings_str << "\n"
            << "  wtype: " << sd_type_name(wtype) << ",\n"
            << "  tensor_type_rules: \"" << tensor_type_rules << "\",\n"
            << "  lora_model_dir: \"" << lora_model_dir << "\",\n"
            << "  photo_maker_path: \"" << photo_maker_path << "\",\n"
            << "  rng_type: " << sd_rng_type_name(rng_type) << ",\n"
            << "  sampler_rng_type: " << sd_rng_type_name(sampler_rng_type) << ",\n"
            << "  flow_shift: " << (std::isinf(flow_shift) ? "INF" : std::to_string(flow_shift)) << "\n"
            << "  offload_params_to_cpu: " << (offload_params_to_cpu ? "true" : "false") << ",\n"
            << "  control_net_cpu: " << (control_net_cpu ? "true" : "false") << ",\n"
            << "  clip_on_cpu: " << (clip_on_cpu ? "true" : "false") << ",\n"
            << "  vae_on_cpu: " << (vae_on_cpu ? "true" : "false") << ",\n"
            << "  diffusion_flash_attn: " << (diffusion_flash_attn ? "true" : "false") << ",\n"
            << "  diffusion_conv_direct: " << (diffusion_conv_direct ? "true" : "false") << ",\n"
            << "  vae_conv_direct: " << (vae_conv_direct ? "true" : "false") << ",\n"
            << "  chroma_use_dit_mask: " << (chroma_use_dit_mask ? "true" : "false") << ",\n"
            << "  chroma_use_t5_mask: " << (chroma_use_t5_mask ? "true" : "false") << ",\n"
            << "  chroma_t5_mask_pad: " << chroma_t5_mask_pad << ",\n"
            << "  prediction: " << sd_prediction_name(prediction) << ",\n"
            << "  lora_apply_mode: " << sd_lora_apply_mode_name(lora_apply_mode) << ",\n"
            << "  vae_tiling_params: { "
            << vae_tiling_params.enabled << ", "
            << vae_tiling_params.tile_size_x << ", "
            << vae_tiling_params.tile_size_y << ", "
            << vae_tiling_params.target_overlap << ", "
            << vae_tiling_params.rel_size_x << ", "
            << vae_tiling_params.rel_size_y << " },\n"
            << "  force_sdxl_vae_conv_scale: " << (force_sdxl_vae_conv_scale ? "true" : "false") << "\n"
            << "}";
        return oss.str();
    }

    sd_ctx_params_t to_sd_ctx_params_t(bool vae_decode_only, bool free_params_immediately, bool taesd_preview) {
        embedding_array.clear();
        embedding_array.reserve(embedding_map.size());
        for (const auto& kv : embedding_map) {
            sd_embedding_t item;
            item.name = kv.first.c_str();
            item.path = kv.second.c_str();
            embedding_array.emplace_back(item);
        }

        sd_ctx_params_t sd_ctx_params = {
            model_path.c_str(),
            clip_l_path.c_str(),
            clip_g_path.c_str(),
            clip_vision_path.c_str(),
            t5xxl_path.c_str(),
            llm_path.c_str(),
            llm_vision_path.c_str(),
            diffusion_model_path.c_str(),
            high_noise_diffusion_model_path.c_str(),
            vae_path.c_str(),
            taesd_path.c_str(),
            control_net_path.c_str(),
            lora_model_dir.c_str(),
            embedding_array.data(),
            static_cast<uint32_t>(embedding_array.size()),
            photo_maker_path.c_str(),
            tensor_type_rules.c_str(),
            vae_decode_only,
            free_params_immediately,
            n_threads,
            wtype,
            rng_type,
            sampler_rng_type,
            prediction,
            lora_apply_mode,
            offload_params_to_cpu,
            clip_on_cpu,
            control_net_cpu,
            vae_on_cpu,
            diffusion_flash_attn,
            taesd_preview,
            diffusion_conv_direct,
            vae_conv_direct,
            force_sdxl_vae_conv_scale,
            chroma_use_dit_mask,
            chroma_use_t5_mask,
            chroma_t5_mask_pad,
            flow_shift,
        };
        return sd_ctx_params;
    }
};

template <typename T>
static std::string vec_to_string(const std::vector<T>& v) {
    std::ostringstream oss;
    oss << "[";
    for (size_t i = 0; i < v.size(); i++) {
        oss << v[i];
        if (i + 1 < v.size())
            oss << ", ";
    }
    oss << "]";
    return oss.str();
}

static std::string vec_str_to_string(const std::vector<std::string>& v) {
    std::ostringstream oss;
    oss << "[";
    for (size_t i = 0; i < v.size(); i++) {
        oss << "\"" << v[i] << "\"";
        if (i + 1 < v.size())
            oss << ", ";
    }
    oss << "]";
    return oss.str();
}

struct SDGenerationParams {
    std::string prompt;
    std::string negative_prompt;
    int clip_skip   = -1;  // <= 0 represents unspecified
    int width       = 512;
    int height      = 512;
    int batch_count = 1;
    std::string init_image_path;
    std::string end_image_path;
    std::string mask_image_path;
    std::string control_image_path;
    std::vector<std::string> ref_image_paths;
    std::string control_video_path;
    bool auto_resize_ref_image = true;
    bool increase_ref_index    = false;

    std::vector<int> skip_layers = {7, 8, 9};
    sd_sample_params_t sample_params;

    std::vector<int> high_noise_skip_layers = {7, 8, 9};
    sd_sample_params_t high_noise_sample_params;

    std::string easycache_option;
    sd_easycache_params_t easycache_params;

    float moe_boundary  = 0.875f;
    int video_frames    = 1;
    int fps             = 16;
    float vace_strength = 1.f;

    float strength         = 0.75f;
    float control_strength = 0.9f;

    int64_t seed = 42;

    // Photo Maker
    std::string pm_id_images_dir;
    std::string pm_id_embed_path;
    float pm_style_strength = 20.f;

    int upscale_repeats = 1;

    SDGenerationParams() {
        sd_sample_params_init(&sample_params);
        sd_sample_params_init(&high_noise_sample_params);
    }

    ArgOptions get_options() {
        ArgOptions options;
        options.string_options = {
            {"-p",
             "--prompt",
             "the prompt to render",
             &prompt},
            {"-n",
             "--negative-prompt",
             "the negative prompt (default: \"\")",
             &negative_prompt},
            {"-i",
             "--init-img",
             "path to the init image",
             &init_image_path},
            {"",
             "--end-img",
             "path to the end image, required by flf2v",
             &end_image_path},
            {"",
             "--mask",
             "path to the mask image",
             &mask_image_path},
            {"",
             "--control-image",
             "path to control image, control net",
             &control_image_path},
            {"",
             "--control-video",
             "path to control video frames, It must be a directory path. The video frames inside should be stored as images in "
             "lexicographical (character) order. For example, if the control video path is `frames`, the directory contain images "
             "such as 00.png, 01.png, ... etc.",
             &control_video_path},
            {"",
             "--pm-id-images-dir",
             "path to PHOTOMAKER input id images dir",
             &pm_id_images_dir},
            {"",
             "--pm-id-embed-path",
             "path to PHOTOMAKER v2 id embed",
             &pm_id_embed_path},
        };

        options.int_options = {
            {"-H",
             "--height",
             "image height, in pixel space (default: 512)",
             &height},
            {"-W",
             "--width",
             "image width, in pixel space (default: 512)",
             &width},
            {"",
             "--steps",
             "number of sample steps (default: 20)",
             &sample_params.sample_steps},
            {"",
             "--high-noise-steps",
             "(high noise) number of sample steps (default: -1 = auto)",
             &high_noise_sample_params.sample_steps},
            {"",
             "--clip-skip",
             "ignore last layers of CLIP network; 1 ignores none, 2 ignores one layer (default: -1). "
             "<= 0 represents unspecified, will be 1 for SD1.x, 2 for SD2.x",
             &clip_skip},
            {"-b",
             "--batch-count",
             "batch count",
             &batch_count},
            {"",
             "--video-frames",
             "video frames (default: 1)",
             &video_frames},
            {"",
             "--fps",
             "fps (default: 24)",
             &fps},
            {"",
             "--timestep-shift",
             "shift timestep for NitroFusion models (default: 0). "
             "recommended N for NitroSD-Realism around 250 and 500 for NitroSD-Vibrant",
             &sample_params.shifted_timestep},
            {"",
             "--upscale-repeats",
             "Run the ESRGAN upscaler this many times (default: 1)",
             &upscale_repeats},
        };

        options.float_options = {
            {"",
             "--cfg-scale",
             "unconditional guidance scale: (default: 7.0)",
             &sample_params.guidance.txt_cfg},
            {"",
             "--img-cfg-scale",
             "image guidance scale for inpaint or instruct-pix2pix models: (default: same as --cfg-scale)",
             &sample_params.guidance.img_cfg},
            {"",
             "--guidance",
             "distilled guidance scale for models with guidance input (default: 3.5)",
             &sample_params.guidance.distilled_guidance},
            {"",
             "--slg-scale",
             "skip layer guidance (SLG) scale, only for DiT models: (default: 0). 0 means disabled, a value of 2.5 is nice for sd3.5 medium",
             &sample_params.guidance.slg.scale},
            {"",
             "--skip-layer-start",
             "SLG enabling point (default: 0.01)",
             &sample_params.guidance.slg.layer_start},
            {"",
             "--skip-layer-end",
             "SLG disabling point (default: 0.2)",
             &sample_params.guidance.slg.layer_end},
            {"",
             "--eta",
             "eta in DDIM, only for DDIM and TCD (default: 0)",
             &sample_params.eta},
            {"",
             "--high-noise-cfg-scale",
             "(high noise) unconditional guidance scale: (default: 7.0)",
             &high_noise_sample_params.guidance.txt_cfg},
            {"",
             "--high-noise-img-cfg-scale",
             "(high noise) image guidance scale for inpaint or instruct-pix2pix models (default: same as --cfg-scale)",
             &high_noise_sample_params.guidance.img_cfg},
            {"",
             "--high-noise-guidance",
             "(high noise) distilled guidance scale for models with guidance input (default: 3.5)",
             &high_noise_sample_params.guidance.distilled_guidance},
            {"",
             "--high-noise-slg-scale",
             "(high noise) skip layer guidance (SLG) scale, only for DiT models: (default: 0)",
             &high_noise_sample_params.guidance.slg.scale},
            {"",
             "--high-noise-skip-layer-start",
             "(high noise) SLG enabling point (default: 0.01)",
             &high_noise_sample_params.guidance.slg.layer_start},
            {"",
             "--high-noise-skip-layer-end",
             "(high noise) SLG disabling point (default: 0.2)",
             &high_noise_sample_params.guidance.slg.layer_end},
            {"",
             "--high-noise-eta",
             "(high noise) eta in DDIM, only for DDIM and TCD (default: 0)",
             &high_noise_sample_params.eta},
            {"",
             "--strength",
             "strength for noising/unnoising (default: 0.75)",
             &strength},
            {"",
             "--pm-style-strength",
             "",
             &pm_style_strength},
            {"",
             "--control-strength",
             "strength to apply Control Net (default: 0.9). 1.0 corresponds to full destruction of information in init image",
             &control_strength},
            {"",
             "--moe-boundary",
             "timestep boundary for Wan2.2 MoE model. (default: 0.875). Only enabled if `--high-noise-steps` is set to -1",
             &moe_boundary},
            {"",
             "--vace-strength",
             "wan vace strength",
             &vace_strength},
        };

        options.bool_options = {
            {"",
             "--increase-ref-index",
             "automatically increase the indices of references images based on the order they are listed (starting with 1).",
             true,
             &increase_ref_index},
            {"",
             "--disable-auto-resize-ref-image",
             "disable auto resize of ref images",
             false,
             &auto_resize_ref_image},
        };

        auto on_seed_arg = [&](int argc, const char** argv, int index) {
            if (++index >= argc) {
                return -1;
            }
            seed = std::stoll(argv[index]);
            return 1;
        };

        auto on_sample_method_arg = [&](int argc, const char** argv, int index) {
            if (++index >= argc) {
                return -1;
            }
            const char* arg             = argv[index];
            sample_params.sample_method = str_to_sample_method(arg);
            if (sample_params.sample_method == SAMPLE_METHOD_COUNT) {
                fprintf(stderr, "error: invalid sample method %s\n",
                        arg);
                return -1;
            }
            return 1;
        };

        auto on_high_noise_sample_method_arg = [&](int argc, const char** argv, int index) {
            if (++index >= argc) {
                return -1;
            }
            const char* arg                        = argv[index];
            high_noise_sample_params.sample_method = str_to_sample_method(arg);
            if (high_noise_sample_params.sample_method == SAMPLE_METHOD_COUNT) {
                fprintf(stderr, "error: invalid high noise sample method %s\n",
                        arg);
                return -1;
            }
            return 1;
        };

        auto on_scheduler_arg = [&](int argc, const char** argv, int index) {
            if (++index >= argc) {
                return -1;
            }
            const char* arg         = argv[index];
            sample_params.scheduler = str_to_scheduler(arg);
            if (sample_params.scheduler == SCHEDULER_COUNT) {
                fprintf(stderr, "error: invalid scheduler %s\n",
                        arg);
                return -1;
            }
            return 1;
        };

        auto on_skip_layers_arg = [&](int argc, const char** argv, int index) {
            if (++index >= argc) {
                return -1;
            }
            std::string layers_str = argv[index];
            if (layers_str[0] != '[' || layers_str[layers_str.size() - 1] != ']') {
                return -1;
            }

            layers_str = layers_str.substr(1, layers_str.size() - 2);

            std::regex regex("[, ]+");
            std::sregex_token_iterator iter(layers_str.begin(), layers_str.end(), regex, -1);
            std::sregex_token_iterator end;
            std::vector<std::string> tokens(iter, end);
            std::vector<int> layers;
            for (const auto& token : tokens) {
                try {
                    layers.push_back(std::stoi(token));
                } catch (const std::invalid_argument&) {
                    return -1;
                }
            }
            skip_layers = layers;
            return 1;
        };

        auto on_high_noise_skip_layers_arg = [&](int argc, const char** argv, int index) {
            if (++index >= argc) {
                return -1;
            }
            std::string layers_str = argv[index];
            if (layers_str[0] != '[' || layers_str[layers_str.size() - 1] != ']') {
                return -1;
            }

            layers_str = layers_str.substr(1, layers_str.size() - 2);

            std::regex regex("[, ]+");
            std::sregex_token_iterator iter(layers_str.begin(), layers_str.end(), regex, -1);
            std::sregex_token_iterator end;
            std::vector<std::string> tokens(iter, end);
            std::vector<int> layers;
            for (const auto& token : tokens) {
                try {
                    layers.push_back(std::stoi(token));
                } catch (const std::invalid_argument&) {
                    return -1;
                }
            }
            high_noise_skip_layers = layers;
            return 1;
        };

        auto on_ref_image_arg = [&](int argc, const char** argv, int index) {
            if (++index >= argc) {
                return -1;
            }
            ref_image_paths.push_back(argv[index]);
            return 1;
        };

        auto on_easycache_arg = [&](int argc, const char** argv, int index) {
            const std::string default_values = "0.2,0.15,0.95";
            auto looks_like_value            = [](const std::string& token) {
                if (token.empty()) {
                    return false;
                }
                if (token[0] != '-') {
                    return true;
                }
                if (token.size() == 1) {
                    return false;
                }
                unsigned char next = static_cast<unsigned char>(token[1]);
                return std::isdigit(next) || token[1] == '.';
            };

            std::string option_value;
            int consumed = 0;
            if (index + 1 < argc) {
                std::string next_arg = argv[index + 1];
                if (looks_like_value(next_arg)) {
                    option_value = argv_to_utf8(index + 1, argv);
                    consumed     = 1;
                }
            }
            if (option_value.empty()) {
                option_value = default_values;
            }
            easycache_option = option_value;
            return consumed;
        };

        options.manual_options = {
            {"-s",
             "--seed",
             "RNG seed (default: 42, use random seed for < 0)",
             on_seed_arg},
            {"",
             "--sampling-method",
             "sampling method, one of [euler, euler_a, heun, dpm2, dpm++2s_a, dpm++2m, dpm++2mv2, ipndm, ipndm_v, lcm, ddim_trailing, tcd] "
             "(default: euler for Flux/SD3/Wan, euler_a otherwise)",
             on_sample_method_arg},
            {"",
             "--high-noise-sampling-method",
             "(high noise) sampling method, one of [euler, euler_a, heun, dpm2, dpm++2s_a, dpm++2m, dpm++2mv2, ipndm, ipndm_v, lcm, ddim_trailing, tcd]"
             " default: euler for Flux/SD3/Wan, euler_a otherwise",
             on_high_noise_sample_method_arg},
            {"",
             "--scheduler",
             "denoiser sigma scheduler, one of [discrete, karras, exponential, ays, gits, smoothstep, sgm_uniform, simple, lcm], default: discrete",
             on_scheduler_arg},
            {"",
             "--skip-layers",
             "layers to skip for SLG steps (default: [7,8,9])",
             on_skip_layers_arg},
            {"",
             "--high-noise-skip-layers",
             "(high noise) layers to skip for SLG steps (default: [7,8,9])",
             on_high_noise_skip_layers_arg},
            {"-r",
             "--ref-image",
             "reference image for Flux Kontext models (can be used multiple times)",
             on_ref_image_arg},
            {"",
             "--easycache",
             "enable EasyCache for DiT models with optional \"threshold,start_percent,end_percent\" (default: 0.2,0.15,0.95)",
             on_easycache_arg},

        };

        return options;
    }

    bool process_and_check(SDMode mode) {
        if (width <= 0) {
            fprintf(stderr, "error: the width must be greater than 0\n");
            return false;
        }

        if (height <= 0) {
            fprintf(stderr, "error: the height must be greater than 0\n");
            return false;
        }

        if (sample_params.sample_steps <= 0) {
            fprintf(stderr, "error: the sample_steps must be greater than 0\n");
            return false;
        }

        if (high_noise_sample_params.sample_steps <= 0) {
            high_noise_sample_params.sample_steps = -1;
        }

        if (strength < 0.f || strength > 1.f) {
            fprintf(stderr, "error: can only work with strength in [0.0, 1.0]\n");
            return false;
        }

        if (!easycache_option.empty()) {
            float values[3] = {0.0f, 0.0f, 0.0f};
            std::stringstream ss(easycache_option);
            std::string token;
            int idx = 0;
            while (std::getline(ss, token, ',')) {
                auto trim = [](std::string& s) {
                    const char* whitespace = " \t\r\n";
                    auto start             = s.find_first_not_of(whitespace);
                    if (start == std::string::npos) {
                        s.clear();
                        return;
                    }
                    auto end = s.find_last_not_of(whitespace);
                    s        = s.substr(start, end - start + 1);
                };
                trim(token);
                if (token.empty()) {
                    fprintf(stderr, "error: invalid easycache option '%s'\n", easycache_option.c_str());
                    return false;
                }
                if (idx >= 3) {
                    fprintf(stderr, "error: easycache expects exactly 3 comma-separated values (threshold,start,end)\n");
                    return false;
                }
                try {
                    values[idx] = std::stof(token);
                } catch (const std::exception&) {
                    fprintf(stderr, "error: invalid easycache value '%s'\n", token.c_str());
                    return false;
                }
                idx++;
            }
            if (idx != 3) {
                fprintf(stderr, "error: easycache expects exactly 3 comma-separated values (threshold,start,end)\n");
                return false;
            }
            if (values[0] < 0.0f) {
                fprintf(stderr, "error: easycache threshold must be non-negative\n");
                return false;
            }
            if (values[1] < 0.0f || values[1] >= 1.0f || values[2] <= 0.0f || values[2] > 1.0f || values[1] >= values[2]) {
                fprintf(stderr, "error: easycache start/end percents must satisfy 0.0 <= start < end <= 1.0\n");
                return false;
            }
            easycache_params.enabled         = true;
            easycache_params.reuse_threshold = values[0];
            easycache_params.start_percent   = values[1];
            easycache_params.end_percent     = values[2];
        } else {
            easycache_params.enabled = false;
        }

        sample_params.guidance.slg.layers                 = skip_layers.data();
        sample_params.guidance.slg.layer_count            = skip_layers.size();
        high_noise_sample_params.guidance.slg.layers      = high_noise_skip_layers.data();
        high_noise_sample_params.guidance.slg.layer_count = high_noise_skip_layers.size();

        if (mode == VID_GEN && video_frames <= 0) {
            return false;
        }

        if (mode == VID_GEN && fps <= 0) {
            return false;
        }

        if (sample_params.shifted_timestep < 0 || sample_params.shifted_timestep > 1000) {
            return false;
        }

        if (upscale_repeats < 1) {
            return false;
        }

        if (mode == UPSCALE) {
            if (init_image_path.length() == 0) {
                fprintf(stderr, "error: upscale mode needs an init image (--init-img)\n");
                return false;
            }
        }

        if (seed < 0) {
            srand((int)time(nullptr));
            seed = rand();
        }

        return true;
    }

    std::string to_string() const {
        char* sample_params_str            = sd_sample_params_to_str(&sample_params);
        char* high_noise_sample_params_str = sd_sample_params_to_str(&high_noise_sample_params);
        std::ostringstream oss;
        oss << "SDGenerationParams {\n"
            << "  prompt: \"" << prompt << "\",\n"
            << "  negative_prompt: \"" << negative_prompt << "\",\n"
            << "  clip_skip: " << clip_skip << ",\n"
            << "  width: " << width << ",\n"
            << "  height: " << height << ",\n"
            << "  batch_count: " << batch_count << ",\n"
            << "  init_image_path: \"" << init_image_path << "\",\n"
            << "  end_image_path: \"" << end_image_path << "\",\n"
            << "  mask_image_path: \"" << mask_image_path << "\",\n"
            << "  control_image_path: \"" << control_image_path << "\",\n"
            << "  ref_image_paths: " << vec_str_to_string(ref_image_paths) << ",\n"
            << "  control_video_path: \"" << control_video_path << "\",\n"
            << "  auto_resize_ref_image: " << (auto_resize_ref_image ? "true" : "false") << ",\n"
            << "  increase_ref_index: " << (increase_ref_index ? "true" : "false") << ",\n"
            << "  pm_id_images_dir: \"" << pm_id_images_dir << "\",\n"
            << "  pm_id_embed_path: \"" << pm_id_embed_path << "\",\n"
            << "  pm_style_strength: " << pm_style_strength << ",\n"
            << "  skip_layers: " << vec_to_string(skip_layers) << ",\n"
            << "  sample_params: " << sample_params_str << ",\n"
            << "  high_noise_skip_layers: " << vec_to_string(high_noise_skip_layers) << ",\n"
            << "  high_noise_sample_params: " << high_noise_sample_params_str << ",\n"
            << "  easycache_option: \"" << easycache_option << "\",\n"
            << "  easycache: "
            << (easycache_params.enabled ? "enabled" : "disabled")
            << " (threshold=" << easycache_params.reuse_threshold
            << ", start=" << easycache_params.start_percent
            << ", end=" << easycache_params.end_percent << "),\n"
            << "  moe_boundary: " << moe_boundary << ",\n"
            << "  video_frames: " << video_frames << ",\n"
            << "  fps: " << fps << ",\n"
            << "  vace_strength: " << vace_strength << ",\n"
            << "  strength: " << strength << ",\n"
            << "  control_strength: " << control_strength << ",\n"
            << "  seed: " << seed << ",\n"
            << "  upscale_repeats: " << upscale_repeats << ",\n"
            << "}";
        free(sample_params_str);
        free(high_noise_sample_params_str);
        return oss.str();
    }
};

void print_usage(int argc, const char* argv[], const std::vector<ArgOptions>& options_list) {
    std::cout << "Usage: " << argv[0] << " [options]\n\n";
    std::cout << "CLI Options:\n";
    options_list[0].print();
    std::cout << "\nContext Options:\n";
    options_list[1].print();
    std::cout << "\nGeneration Options:\n";
    options_list[2].print();
}

void parse_args(int argc, const char** argv, SDCliParams& cli_params, SDContextParams& ctx_params, SDGenerationParams& gen_params) {
    std::vector<ArgOptions> options_vec = {cli_params.get_options(), ctx_params.get_options(), gen_params.get_options()};

    if (!parse_options(argc, argv, options_vec)) {
        print_usage(argc, argv, options_vec);
        exit(cli_params.normal_exit ? 0 : 1);
    }

    if (!cli_params.process_and_check() || !ctx_params.process_and_check(cli_params.mode) || !gen_params.process_and_check(cli_params.mode)) {
        print_usage(argc, argv, options_vec);
        exit(1);
    }
}

static std::string sd_basename(const std::string& path) {
    size_t pos = path.find_last_of('/');
    if (pos != std::string::npos) {
        return path.substr(pos + 1);
    }
    pos = path.find_last_of('\\');
    if (pos != std::string::npos) {
        return path.substr(pos + 1);
    }
    return path;
}

std::string get_image_params(const SDCliParams& cli_params, const SDContextParams& ctx_params, const SDGenerationParams& gen_params, int64_t seed) {
    std::string parameter_string = gen_params.prompt + "\n";
    if (gen_params.negative_prompt.size() != 0) {
        parameter_string += "Negative prompt: " + gen_params.negative_prompt + "\n";
    }
    parameter_string += "Steps: " + std::to_string(gen_params.sample_params.sample_steps) + ", ";
    parameter_string += "CFG scale: " + std::to_string(gen_params.sample_params.guidance.txt_cfg) + ", ";
    if (gen_params.sample_params.guidance.slg.scale != 0 && gen_params.skip_layers.size() != 0) {
        parameter_string += "SLG scale: " + std::to_string(gen_params.sample_params.guidance.txt_cfg) + ", ";
        parameter_string += "Skip layers: [";
        for (const auto& layer : gen_params.skip_layers) {
            parameter_string += std::to_string(layer) + ", ";
        }
        parameter_string += "], ";
        parameter_string += "Skip layer start: " + std::to_string(gen_params.sample_params.guidance.slg.layer_start) + ", ";
        parameter_string += "Skip layer end: " + std::to_string(gen_params.sample_params.guidance.slg.layer_end) + ", ";
    }
    parameter_string += "Guidance: " + std::to_string(gen_params.sample_params.guidance.distilled_guidance) + ", ";
    parameter_string += "Eta: " + std::to_string(gen_params.sample_params.eta) + ", ";
    parameter_string += "Seed: " + std::to_string(seed) + ", ";
    parameter_string += "Size: " + std::to_string(gen_params.width) + "x" + std::to_string(gen_params.height) + ", ";
    parameter_string += "Model: " + sd_basename(ctx_params.model_path) + ", ";
    parameter_string += "RNG: " + std::string(sd_rng_type_name(ctx_params.rng_type)) + ", ";
    if (ctx_params.sampler_rng_type != RNG_TYPE_COUNT) {
        parameter_string += "Sampler RNG: " + std::string(sd_rng_type_name(ctx_params.sampler_rng_type)) + ", ";
    }
    parameter_string += "Sampler: " + std::string(sd_sample_method_name(gen_params.sample_params.sample_method));
    if (gen_params.sample_params.scheduler != SCHEDULER_COUNT) {
        parameter_string += " " + std::string(sd_scheduler_name(gen_params.sample_params.scheduler));
    }
    parameter_string += ", ";
    for (const auto& te : {ctx_params.clip_l_path, ctx_params.clip_g_path, ctx_params.t5xxl_path, ctx_params.llm_path, ctx_params.llm_vision_path}) {
        if (!te.empty()) {
            parameter_string += "TE: " + sd_basename(te) + ", ";
        }
    }
    if (!ctx_params.diffusion_model_path.empty()) {
        parameter_string += "Unet: " + sd_basename(ctx_params.diffusion_model_path) + ", ";
    }
    if (!ctx_params.vae_path.empty()) {
        parameter_string += "VAE: " + sd_basename(ctx_params.vae_path) + ", ";
    }
    if (gen_params.clip_skip != -1) {
        parameter_string += "Clip skip: " + std::to_string(gen_params.clip_skip) + ", ";
    }
    parameter_string += "Version: stable-diffusion.cpp";
    return parameter_string;
}

/* Enables Printing the log level tag in color using ANSI escape codes */
void sd_log_cb(enum sd_log_level_t level, const char* log, void* data) {
    SDCliParams* cli_params = (SDCliParams*)data;
    int tag_color;
    const char* level_str;
    FILE* out_stream = (level == SD_LOG_ERROR) ? stderr : stdout;

    if (!log || (!cli_params->verbose && level <= SD_LOG_DEBUG)) {
        return;
    }

    switch (level) {
        case SD_LOG_DEBUG:
            tag_color = 37;
            level_str = "DEBUG";
            break;
        case SD_LOG_INFO:
            tag_color = 34;
            level_str = "INFO";
            break;
        case SD_LOG_WARN:
            tag_color = 35;
            level_str = "WARN";
            break;
        case SD_LOG_ERROR:
            tag_color = 31;
            level_str = "ERROR";
            break;
        default: /* Potential future-proofing */
            tag_color = 33;
            level_str = "?????";
            break;
    }

    if (cli_params->color == true) {
        fprintf(out_stream, "\033[%d;1m[%-5s]\033[0m ", tag_color, level_str);
    } else {
        fprintf(out_stream, "[%-5s] ", level_str);
    }
    fputs(log, out_stream);
    fflush(out_stream);
}

uint8_t* load_image(const char* image_path, int& width, int& height, int expected_width = 0, int expected_height = 0, int expected_channel = 3) {
    int c                 = 0;
    uint8_t* image_buffer = (uint8_t*)stbi_load(image_path, &width, &height, &c, expected_channel);
    if (image_buffer == nullptr) {
        fprintf(stderr, "load image from '%s' failed\n", image_path);
        return nullptr;
    }
    if (c < expected_channel) {
        fprintf(stderr,
                "the number of channels for the input image must be >= %d,"
                "but got %d channels, image_path = %s\n",
                expected_channel,
                c,
                image_path);
        free(image_buffer);
        return nullptr;
    }
    if (width <= 0) {
        fprintf(stderr, "error: the width of image must be greater than 0, image_path = %s\n", image_path);
        free(image_buffer);
        return nullptr;
    }
    if (height <= 0) {
        fprintf(stderr, "error: the height of image must be greater than 0, image_path = %s\n", image_path);
        free(image_buffer);
        return nullptr;
    }

    // Resize input image ...
    if ((expected_width > 0 && expected_height > 0) && (height != expected_height || width != expected_width)) {
        float dst_aspect = (float)expected_width / (float)expected_height;
        float src_aspect = (float)width / (float)height;

        int crop_x = 0, crop_y = 0;
        int crop_w = width, crop_h = height;

        if (src_aspect > dst_aspect) {
            crop_w = (int)(height * dst_aspect);
            crop_x = (width - crop_w) / 2;
        } else if (src_aspect < dst_aspect) {
            crop_h = (int)(width / dst_aspect);
            crop_y = (height - crop_h) / 2;
        }

        if (crop_x != 0 || crop_y != 0) {
            printf("crop input image from %dx%d to %dx%d, image_path = %s\n", width, height, crop_w, crop_h, image_path);
            uint8_t* cropped_image_buffer = (uint8_t*)malloc(crop_w * crop_h * expected_channel);
            if (cropped_image_buffer == nullptr) {
                fprintf(stderr, "error: allocate memory for crop\n");
                free(image_buffer);
                return nullptr;
            }
            for (int row = 0; row < crop_h; row++) {
                uint8_t* src = image_buffer + ((crop_y + row) * width + crop_x) * expected_channel;
                uint8_t* dst = cropped_image_buffer + (row * crop_w) * expected_channel;
                memcpy(dst, src, crop_w * expected_channel);
            }

            width  = crop_w;
            height = crop_h;
            free(image_buffer);
            image_buffer = cropped_image_buffer;
        }

        printf("resize input image from %dx%d to %dx%d\n", width, height, expected_width, expected_height);
        int resized_height = expected_height;
        int resized_width  = expected_width;

        uint8_t* resized_image_buffer = (uint8_t*)malloc(resized_height * resized_width * expected_channel);
        if (resized_image_buffer == nullptr) {
            fprintf(stderr, "error: allocate memory for resize input image\n");
            free(image_buffer);
            return nullptr;
        }
        stbir_resize(image_buffer, width, height, 0,
                     resized_image_buffer, resized_width, resized_height, 0, STBIR_TYPE_UINT8,
                     expected_channel, STBIR_ALPHA_CHANNEL_NONE, 0,
                     STBIR_EDGE_CLAMP, STBIR_EDGE_CLAMP,
                     STBIR_FILTER_BOX, STBIR_FILTER_BOX,
                     STBIR_COLORSPACE_SRGB, nullptr);
        width  = resized_width;
        height = resized_height;
        free(image_buffer);
        image_buffer = resized_image_buffer;
    }
    return image_buffer;
}

bool load_images_from_dir(const std::string dir,
                          std::vector<sd_image_t>& images,
                          int expected_width  = 0,
                          int expected_height = 0,
                          int max_image_num   = 0,
                          bool verbose        = false) {
    if (!fs::exists(dir) || !fs::is_directory(dir)) {
        fprintf(stderr, "'%s' is not a valid directory\n", dir.c_str());
        return false;
    }

    std::vector<fs::directory_entry> entries;
    for (const auto& entry : fs::directory_iterator(dir)) {
        if (entry.is_regular_file()) {
            entries.push_back(entry);
        }
    }

    std::sort(entries.begin(), entries.end(),
              [](const fs::directory_entry& a, const fs::directory_entry& b) {
                  return a.path().filename().string() < b.path().filename().string();
              });

    for (const auto& entry : entries) {
        std::string path = entry.path().string();
        std::string ext  = entry.path().extension().string();
        std::transform(ext.begin(), ext.end(), ext.begin(), ::tolower);

        if (ext == ".jpg" || ext == ".jpeg" || ext == ".png" || ext == ".bmp") {
            if (verbose) {
                printf("load image %zu from '%s'\n", images.size(), path.c_str());
            }
            int width             = 0;
            int height            = 0;
            uint8_t* image_buffer = load_image(path.c_str(), width, height, expected_width, expected_height);
            if (image_buffer == nullptr) {
                fprintf(stderr, "load image from '%s' failed\n", path.c_str());
                return false;
            }

            images.push_back({(uint32_t)width,
                              (uint32_t)height,
                              3,
                              image_buffer});

            if (max_image_num > 0 && images.size() >= max_image_num) {
                break;
            }
        }
    }
    return true;
}

void step_callback(int step, int frame_count, sd_image_t* image, bool is_noisy, void* data) {
    (void)step;
    (void)is_noisy;
    SDCliParams* cli_params = (SDCliParams*)data;
    // is_noisy is set to true if the preview corresponds to noisy latents, false if it's denoised latents
    // unused in this app, it will either be always noisy or always denoised here
    if (frame_count == 1) {
        stbi_write_png(cli_params->preview_path.c_str(), image->width, image->height, image->channel, image->data, 0);
    } else {
        create_mjpg_avi_from_sd_images(cli_params->preview_path.c_str(), image, frame_count, cli_params->preview_fps);
    }
}

int main(int argc, const char* argv[]) {
    SDCliParams cli_params;
    SDContextParams ctx_params;
    SDGenerationParams gen_params;

    parse_args(argc, argv, cli_params, ctx_params, gen_params);
    if (gen_params.video_frames > 4) {
        size_t last_dot_pos   = cli_params.preview_path.find_last_of(".");
        std::string base_path = cli_params.preview_path;
        std::string file_ext  = "";
        if (last_dot_pos != std::string::npos) {  // filename has extension
            base_path = cli_params.preview_path.substr(0, last_dot_pos);
            file_ext  = cli_params.preview_path.substr(last_dot_pos);
            std::transform(file_ext.begin(), file_ext.end(), file_ext.begin(), ::tolower);
        }
        if (file_ext == ".png") {
            cli_params.preview_path = base_path + ".avi";
        }
    }
    cli_params.preview_fps = gen_params.fps;
    if (cli_params.preview_method == PREVIEW_PROJ)
        cli_params.preview_fps /= 4;

    sd_set_log_callback(sd_log_cb, (void*)&cli_params);
    sd_set_preview_callback(step_callback,
                            cli_params.preview_method,
                            cli_params.preview_interval,
                            !cli_params.preview_noisy,
                            cli_params.preview_noisy,
                            (void*)&cli_params);

    if (cli_params.verbose) {
        printf("%s", sd_get_system_info());
        printf("%s\n", cli_params.to_string().c_str());
        printf("%s\n", ctx_params.to_string().c_str());
        printf("%s\n", gen_params.to_string().c_str());
    }

    if (cli_params.mode == CONVERT) {
        bool success = convert(ctx_params.model_path.c_str(),
                               ctx_params.vae_path.c_str(),
                               cli_params.output_path.c_str(),
                               ctx_params.wtype,
                               ctx_params.tensor_type_rules.c_str());
        if (!success) {
            fprintf(stderr,
                    "convert '%s'/'%s' to '%s' failed\n",
                    ctx_params.model_path.c_str(),
                    ctx_params.vae_path.c_str(),
                    cli_params.output_path.c_str());
            return 1;
        } else {
            printf("convert '%s'/'%s' to '%s' success\n",
                   ctx_params.model_path.c_str(),
                   ctx_params.vae_path.c_str(),
                   cli_params.output_path.c_str());
            return 0;
        }
    }

    bool vae_decode_only     = true;
    sd_image_t init_image    = {(uint32_t)gen_params.width, (uint32_t)gen_params.height, 3, nullptr};
    sd_image_t end_image     = {(uint32_t)gen_params.width, (uint32_t)gen_params.height, 3, nullptr};
    sd_image_t control_image = {(uint32_t)gen_params.width, (uint32_t)gen_params.height, 3, nullptr};
    sd_image_t mask_image    = {(uint32_t)gen_params.width, (uint32_t)gen_params.height, 1, nullptr};
    std::vector<sd_image_t> ref_images;
    std::vector<sd_image_t> pmid_images;
    std::vector<sd_image_t> control_frames;

    auto release_all_resources = [&]() {
        free(init_image.data);
        free(end_image.data);
        free(control_image.data);
        free(mask_image.data);
        for (auto image : ref_images) {
            free(image.data);
            image.data = nullptr;
        }
        ref_images.clear();
        for (auto image : pmid_images) {
            free(image.data);
            image.data = nullptr;
        }
        pmid_images.clear();
        for (auto image : control_frames) {
            free(image.data);
            image.data = nullptr;
        }
        control_frames.clear();
    };

    if (gen_params.init_image_path.size() > 0) {
        vae_decode_only = false;

        int width       = 0;
        int height      = 0;
        init_image.data = load_image(gen_params.init_image_path.c_str(), width, height, gen_params.width, gen_params.height);
        if (init_image.data == nullptr) {
            fprintf(stderr, "load image from '%s' failed\n", gen_params.init_image_path.c_str());
            release_all_resources();
            return 1;
        }
    }

    if (gen_params.end_image_path.size() > 0) {
        vae_decode_only = false;

        int width      = 0;
        int height     = 0;
        end_image.data = load_image(gen_params.end_image_path.c_str(), width, height, gen_params.width, gen_params.height);
        if (end_image.data == nullptr) {
            fprintf(stderr, "load image from '%s' failed\n", gen_params.end_image_path.c_str());
            release_all_resources();
            return 1;
        }
    }

    if (gen_params.mask_image_path.size() > 0) {
        int c           = 0;
        int width       = 0;
        int height      = 0;
        mask_image.data = load_image(gen_params.mask_image_path.c_str(), width, height, gen_params.width, gen_params.height, 1);
        if (mask_image.data == nullptr) {
            fprintf(stderr, "load image from '%s' failed\n", gen_params.mask_image_path.c_str());
            release_all_resources();
            return 1;
        }
    } else {
        mask_image.data = (uint8_t*)malloc(gen_params.width * gen_params.height);
        memset(mask_image.data, 255, gen_params.width * gen_params.height);
        if (mask_image.data == nullptr) {
            fprintf(stderr, "malloc mask image failed\n");
            release_all_resources();
            return 1;
        }
    }

    if (gen_params.control_image_path.size() > 0) {
        int width          = 0;
        int height         = 0;
        control_image.data = load_image(gen_params.control_image_path.c_str(), width, height, gen_params.width, gen_params.height);
        if (control_image.data == nullptr) {
            fprintf(stderr, "load image from '%s' failed\n", gen_params.control_image_path.c_str());
            release_all_resources();
            return 1;
        }
        if (cli_params.canny_preprocess) {  // apply preprocessor
            preprocess_canny(control_image,
                             0.08f,
                             0.08f,
                             0.8f,
                             1.0f,
                             false);
        }
    }

    if (gen_params.ref_image_paths.size() > 0) {
        vae_decode_only = false;
        for (auto& path : gen_params.ref_image_paths) {
            int width             = 0;
            int height            = 0;
            uint8_t* image_buffer = load_image(path.c_str(), width, height);
            if (image_buffer == nullptr) {
                fprintf(stderr, "load image from '%s' failed\n", path.c_str());
                release_all_resources();
                return 1;
            }
            ref_images.push_back({(uint32_t)width,
                                  (uint32_t)height,
                                  3,
                                  image_buffer});
        }
    }

    if (!gen_params.control_video_path.empty()) {
        if (!load_images_from_dir(gen_params.control_video_path,
                                  control_frames,
                                  gen_params.width,
                                  gen_params.height,
                                  gen_params.video_frames,
                                  cli_params.verbose)) {
            release_all_resources();
            return 1;
        }
    }

    if (!gen_params.pm_id_images_dir.empty()) {
        if (!load_images_from_dir(gen_params.pm_id_images_dir,
                                  pmid_images,
                                  0,
                                  0,
                                  0,
                                  cli_params.verbose)) {
            release_all_resources();
            return 1;
        }
    }

    if (cli_params.mode == VID_GEN) {
        vae_decode_only = false;
    }

    sd_ctx_params_t sd_ctx_params = ctx_params.to_sd_ctx_params_t(vae_decode_only, true, cli_params.taesd_preview);

    sd_image_t* results = nullptr;
    int num_results     = 0;

    if (cli_params.mode == UPSCALE) {
        num_results = 1;
        results     = (sd_image_t*)calloc(num_results, sizeof(sd_image_t));
        if (results == nullptr) {
            printf("failed to allocate results array\n");
            release_all_resources();
            return 1;
        }

        results[0]      = init_image;
        init_image.data = nullptr;
    } else {
        sd_ctx_t* sd_ctx = new_sd_ctx(&sd_ctx_params);

        if (sd_ctx == nullptr) {
            printf("new_sd_ctx_t failed\n");
            release_all_resources();
            return 1;
        }

        if (gen_params.sample_params.sample_method == SAMPLE_METHOD_COUNT) {
            gen_params.sample_params.sample_method = sd_get_default_sample_method(sd_ctx);
        }

        if (gen_params.high_noise_sample_params.sample_method == SAMPLE_METHOD_COUNT) {
            gen_params.high_noise_sample_params.sample_method = sd_get_default_sample_method(sd_ctx);
        }

        if (gen_params.sample_params.scheduler == SCHEDULER_COUNT) {
            gen_params.sample_params.scheduler = sd_get_default_scheduler(sd_ctx);
        }

        if (cli_params.mode == IMG_GEN) {
            sd_img_gen_params_t img_gen_params = {
                gen_params.prompt.c_str(),
                gen_params.negative_prompt.c_str(),
                gen_params.clip_skip,
                init_image,
                ref_images.data(),
                (int)ref_images.size(),
                gen_params.auto_resize_ref_image,
                gen_params.increase_ref_index,
                mask_image,
                gen_params.width,
                gen_params.height,
                gen_params.sample_params,
                gen_params.strength,
                gen_params.seed,
                gen_params.batch_count,
                control_image,
                gen_params.control_strength,
                {
                    pmid_images.data(),
                    (int)pmid_images.size(),
                    gen_params.pm_id_embed_path.c_str(),
                    gen_params.pm_style_strength,
                },  // pm_params
                ctx_params.vae_tiling_params,
                gen_params.easycache_params,
            };

            results     = generate_image(sd_ctx, &img_gen_params);
            num_results = gen_params.batch_count;
        } else if (cli_params.mode == VID_GEN) {
            sd_vid_gen_params_t vid_gen_params = {
                gen_params.prompt.c_str(),
                gen_params.negative_prompt.c_str(),
                gen_params.clip_skip,
                init_image,
                end_image,
                control_frames.data(),
                (int)control_frames.size(),
                gen_params.width,
                gen_params.height,
                gen_params.sample_params,
                gen_params.high_noise_sample_params,
                gen_params.moe_boundary,
                gen_params.strength,
                gen_params.seed,
                gen_params.video_frames,
                gen_params.vace_strength,
                gen_params.easycache_params,
            };

            results = generate_video(sd_ctx, &vid_gen_params, &num_results);
        }

        if (results == nullptr) {
            printf("generate failed\n");
            free_sd_ctx(sd_ctx);
            return 1;
        }

        free_sd_ctx(sd_ctx);
    }

    int upscale_factor = 4;  // unused for RealESRGAN_x4plus_anime_6B.pth
    if (ctx_params.esrgan_path.size() > 0 && gen_params.upscale_repeats > 0) {
        upscaler_ctx_t* upscaler_ctx = new_upscaler_ctx(ctx_params.esrgan_path.c_str(),
                                                        ctx_params.offload_params_to_cpu,
                                                        ctx_params.diffusion_conv_direct,
                                                        ctx_params.n_threads);

        if (upscaler_ctx == nullptr) {
            printf("new_upscaler_ctx failed\n");
        } else {
            for (int i = 0; i < num_results; i++) {
                if (results[i].data == nullptr) {
                    continue;
                }
                sd_image_t current_image = results[i];
                for (int u = 0; u < gen_params.upscale_repeats; ++u) {
                    sd_image_t upscaled_image = upscale(upscaler_ctx, current_image, upscale_factor);
                    if (upscaled_image.data == nullptr) {
                        printf("upscale failed\n");
                        break;
                    }
                    free(current_image.data);
                    current_image = upscaled_image;
                }
                results[i] = current_image;  // Set the final upscaled image as the result
            }
        }
    }

    // create directory if not exists
    {
        const fs::path out_path = cli_params.output_path;
        if (const fs::path out_dir = out_path.parent_path(); !out_dir.empty()) {
            std::error_code ec;
            fs::create_directories(out_dir, ec);  // OK if already exists
            if (ec) {
                fprintf(stderr, "failed to create directory '%s': %s\n",
                        out_dir.string().c_str(), ec.message().c_str());
                return 1;
            }
        }
    }

    std::string base_path;
    std::string file_ext;
    std::string file_ext_lower;
    bool is_jpg;
    size_t last_dot_pos   = cli_params.output_path.find_last_of(".");
    size_t last_slash_pos = std::min(cli_params.output_path.find_last_of("/"),
                                     cli_params.output_path.find_last_of("\\"));
    if (last_dot_pos != std::string::npos && (last_slash_pos == std::string::npos || last_dot_pos > last_slash_pos)) {  // filename has extension
        base_path = cli_params.output_path.substr(0, last_dot_pos);
        file_ext = file_ext_lower = cli_params.output_path.substr(last_dot_pos);
        std::transform(file_ext.begin(), file_ext.end(), file_ext_lower.begin(), ::tolower);
        is_jpg = (file_ext_lower == ".jpg" || file_ext_lower == ".jpeg" || file_ext_lower == ".jpe");
    } else {
        base_path = cli_params.output_path;
        file_ext = file_ext_lower = "";
        is_jpg                    = false;
    }

    if (cli_params.mode == VID_GEN && num_results > 1) {
        std::string vid_output_path = cli_params.output_path;
        if (file_ext_lower == ".png") {
            vid_output_path = base_path + ".avi";
        }
        create_mjpg_avi_from_sd_images(vid_output_path.c_str(), results, num_results, gen_params.fps);
        printf("save result MJPG AVI video to '%s'\n", vid_output_path.c_str());
    } else {
        // appending ".png" to absent or unknown extension
        if (!is_jpg && file_ext_lower != ".png") {
            base_path += file_ext;
            file_ext = ".png";
        }
        for (int i = 0; i < num_results; i++) {
            if (results[i].data == nullptr) {
                continue;
            }
            int write_ok;
            std::string final_image_path = i > 0 ? base_path + "_" + std::to_string(i + 1) + file_ext : base_path + file_ext;
            if (is_jpg) {
                write_ok = stbi_write_jpg(final_image_path.c_str(), results[i].width, results[i].height, results[i].channel,
                                          results[i].data, 90, get_image_params(cli_params, ctx_params, gen_params, gen_params.seed + i).c_str());
                printf("save result JPEG image to '%s' (%s)\n", final_image_path.c_str(), write_ok == 0 ? "failure" : "success");
            } else {
                write_ok = stbi_write_png(final_image_path.c_str(), results[i].width, results[i].height, results[i].channel,
                                          results[i].data, 0, get_image_params(cli_params, ctx_params, gen_params, gen_params.seed + i).c_str());
                printf("save result PNG image to '%s' (%s)\n", final_image_path.c_str(), write_ok == 0 ? "failure" : "success");
            }
        }
    }

    for (int i = 0; i < num_results; i++) {
        free(results[i].data);
        results[i].data = nullptr;
    }
    free(results);

    release_all_resources();

    return 0;
}
